{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import mxnet as mx\n",
    "from mxnet import init, gluon, nd, autograd\n",
    "from mxnet.gluon import nn\n",
    "import numpy as np\n",
    "import pickle as p\n",
    "import matplotlib.pyplot as plt\n",
    "from time import time\n",
    "%matplotlib inline\n",
    "ctx = mx.gpu()\n",
    "data_dir = '/home/sinyer/python/data/cifar10'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def load_cifar(route = data_dir+'/cifar-10-batches-py'):\n",
    "    def load_batch(filename):\n",
    "        with open(filename, 'rb')as f:\n",
    "            data_dict = p.load(f, encoding='latin1')\n",
    "            X = data_dict['data']\n",
    "            Y = data_dict['labels']\n",
    "            X = X.reshape(10000, 3, 32,32).astype(\"float\")\n",
    "            Y = np.array(Y)\n",
    "            return X, Y\n",
    "    def load_labels(filename):\n",
    "        with open(filename, 'rb') as f:\n",
    "            label_names = p.load(f, encoding='latin1')\n",
    "            names = label_names['label_names']\n",
    "            return names\n",
    "    label_names = load_labels(route + \"/batches.meta\")\n",
    "    x1, y1 = load_batch(route + \"/data_batch_1\")\n",
    "    x2, y2 = load_batch(route + \"/data_batch_2\")\n",
    "    x3, y3 = load_batch(route + \"/data_batch_3\")\n",
    "    x4, y4 = load_batch(route + \"/data_batch_4\")\n",
    "    x5, y5 = load_batch(route + \"/data_batch_5\")\n",
    "    test_pic, test_label = load_batch(route + \"/test_batch\")\n",
    "    train_pic = np.concatenate((x1, x2, x3, x4, x5))\n",
    "    train_label = np.concatenate((y1, y2, y3, y4, y5))\n",
    "    return train_pic, train_label, test_pic, test_label\n",
    "\n",
    "def accuracy(output, label):\n",
    "    return nd.mean(output.argmax(axis=1)==label).asscalar()\n",
    "\n",
    "def evaluate_accuracy(test_data, net, ctx):\n",
    "    acc = 0.\n",
    "    for data, label in test_data:\n",
    "        data = data.as_in_context(ctx)\n",
    "        label = label.as_in_context(ctx)\n",
    "        output = net(data)\n",
    "        acc = acc + accuracy(output, label)\n",
    "    return acc / len(test_data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "class Residual(nn.Block):\n",
    "    def __init__(self, channels, same_shape=True, **kwargs):\n",
    "        super(Residual, self).__init__(**kwargs)\n",
    "        self.same_shape = same_shape\n",
    "        with self.name_scope():\n",
    "            strides = 1 if same_shape else 2\n",
    "            self.bn1 = nn.BatchNorm()\n",
    "            self.conv1 = nn.Conv2D(channels, kernel_size=3, padding=1, strides=strides)\n",
    "            self.bn2 = nn.BatchNorm()\n",
    "            self.conv2 = nn.Conv2D(channels, kernel_size=3, padding=1)\n",
    "            if not same_shape:\n",
    "                self.conv3 = nn.Conv2D(channels, kernel_size=1, strides=strides)\n",
    "    def forward(self, x):\n",
    "        out = self.conv1(nd.relu(self.bn1(x)))\n",
    "        out = self.conv2(nd.relu(self.bn2(out)))\n",
    "        if not self.same_shape:\n",
    "            x = self.conv3(x)\n",
    "        return out + x\n",
    "\n",
    "class ResNet(nn.Block):\n",
    "    def __init__(self, num_classes, **kwargs):\n",
    "        super(ResNet, self).__init__(**kwargs)\n",
    "        with self.name_scope(): \n",
    "            b1 = nn.Conv2D(16, kernel_size=3, strides=1, padding=1)\n",
    "            b2 = nn.Sequential()\n",
    "            for _ in range(8):\n",
    "                b2.add(Residual(16))\n",
    "            b3 = nn.Sequential()\n",
    "            b3.add(Residual(32, same_shape=False))\n",
    "            for _ in range(7):\n",
    "                b3.add(Residual(32))\n",
    "            b4 = nn.Sequential()\n",
    "            b4.add(Residual(64, same_shape=False))\n",
    "            for _ in range(7):\n",
    "                b4.add(Residual(64))\n",
    "            b5 = nn.Sequential()\n",
    "            b5.add(nn.BatchNorm(),nn.Activation(activation=\"relu\"),nn.AvgPool2D(pool_size=8),\n",
    "                   nn.Dense(num_classes))\n",
    "            self.net = nn.Sequential()\n",
    "            self.net.add(b1, b2, b3, b4, b5)\n",
    "    def forward(self, x):\n",
    "        out = x\n",
    "        for b in self.net:\n",
    "            out = b(out)\n",
    "        return out"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "train_pic, train_label, test_pic, test_label = load_cifar()\n",
    "\n",
    "batch_size = 128\n",
    "train_data = gluon.data.DataLoader(gluon.data.ArrayDataset(\n",
    "    train_pic.astype('float32')/255, train_label.astype('float32')), batch_size, shuffle=True)\n",
    "test_data = gluon.data.DataLoader(gluon.data.ArrayDataset(\n",
    "    test_pic.astype('float32')/255, test_label.astype('float32')), batch_size, shuffle=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "net = ResNet(10)\n",
    "net.initialize(ctx=ctx, init=init.Xavier())\n",
    "loss = gluon.loss.SoftmaxCrossEntropyLoss()\n",
    "trainer = gluon.Trainer(net.collect_params(), 'nag', {'learning_rate': 0.05, 'momentum': 0.9, 'wd': 5e-4})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0 loss:1.4943 tracc:0.4432 teacc:0.4644 time:30.037\n",
      "10 loss:0.3551 tracc:0.8771 teacc:0.7657 time:30.141\n",
      "20 loss:0.2575 tracc:0.9104 teacc:0.8035 time:30.160\n",
      "30 loss:0.2126 tracc:0.9260 teacc:0.8007 time:30.219\n",
      "40 loss:0.0705 tracc:0.9790 teacc:0.8833 time:30.221\n",
      "50 loss:0.0028 tracc:1.0000 teacc:0.8839 time:30.255\n",
      "tracc:1.000000 teacc:0.879648\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXcAAAD8CAYAAACMwORRAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMS4wLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvpW3flQAAIABJREFUeJzt3Xl8XHW9//HXJ/vaNG3ShaYrLbSl\ntGBLWSqyFSiLIHpBEL1XUKsXUdwvXL0ueH8ucBW5iihetQpcEQGhF9CCrAoUulAo3dMlNFuTtM2+\nz3x/f3wnbZomzaSddDKn7+fjMY+Zc+abM9+TTN7zne/5nu8x5xwiIhIsSfGugIiIxJ7CXUQkgBTu\nIiIBpHAXEQkghbuISAAp3EVEAqjfcDez35hZlZm908fzZmb/bWbFZva2mb0n9tUUEZGBiKblvgRY\ndIjnLwGmRW6LgXuPvFoiInIk+g1359zLwJ5DFLkS+L3zlgPDzWxsrCooIiIDlxKDbYwDdnZbLo2s\nq+hZ0MwW41v3ZGdnz50+fXoMXl5EDkco7Kisb6W1I0RnyNEZdoR1xvpRMW54JiOy0w7rZ1etWlXj\nnCvsr1wswt16WdfrO8Q5dx9wH8C8efPcypUrY/DyIjJQ75TVcdODq2mobeH8KSMYmZ3OyJw0CnLS\nGZGdRm5GCmnJSaSl+Ft6ShIpSUmkJBspSUkkJxkpSUaS+X9/F/mXd673f37nHGb7o6LrUdcqi6yx\nHmnSfbnr562P5wfCub5/tuvzze1bPrDufW/THVDHQ/3IsIxUstMPL37NrCSacrEI91JgfLflIqA8\nBtsVkRhzzvHA8hK+++QGRuak8cdPn8ncifnxrpYMgliE+1LgZjN7CDgdqHPOHdQlIyLx1dDawa2P\nreWptys498RCfnzNKYfdNSBDX7/hbmZ/AM4FCsysFPgWkArgnPsF8DRwKVAMNAM3DFZlReTw1Ld2\n8IF7XqFkdzNfW3Qin3nf8SQlHWafhiSEfsPdOXddP8874LMxq5GIxNyvXt7Gtuom7v/EfM6e1u+x\nOAkAnaEqEnDVDW38+h/buWz2WAX7MUThLhJw97xQTFtnmC9feEK8qyJHkcJdJMBK9zbzv6+/y9Vz\ni5hSmBPv6shRpHAXCbCf/G0LGNyycFq8qyJHmcJdJKC27GrgsdWl/PMZExmblxnv6shRpnAXCaj/\nemYTWWkp3HTe1HhXReJA4S4SQGt21rJs3S4+efZknah0jIrFGaoiEmfhyKRfIedwDu5ctpER2Wl8\n8uwp8a6axInCXeQoaWrrpKKuhbbOMLB/giqAvc3tVNS1squulYr6VirrWnHOUZibzqjcDEYNS6cw\nJ52kJKNkdxPba5rZUdNEye4mKupb6W0yx/+4fCY5hzk5lSQ+/eVFumlpD1Hb0s7epg5qm9tp7QwB\n+2ctxKCtI0xNY9v+W0M79a0dpKckkZmWTGZqCplpfhbFqoZWSve2ULq3hT1N7VHVYUR2GmOGZZCU\nBOsr6qlpbCcUdgeVmTgyizOmjGTs8AxSk5NINiMpMlPjyJw0PnjquJj+biSxKNwlITjneHdPM8u3\n7eb1bXsor2th+phhnDwuj1nj8ji+MJuU5CQ6QmG2VTexoaKeDRX1bKlqpK0zRCjsCIch5Hz3RUco\nTHtnt1soTENr575WdbTys1IpyEknLzOVxrZOmveGaGkP0doRor0zTOGwdIrys5g1Lo+i/EzGDc8k\nIzV53893ze6Sl5nK2LxMRg1LP+B58POu72lqp7qhjc5wmIkjs8nLTD3SX6kEnMJdhoyOUHhfS7im\nsY3qSMt4c2UDy7ftobK+FYCR2WkUjcjijyt2suTVHQBkpCZRlJ/Fu7ubaQ/5gE5LTmJKYTbZ6SmR\nVi2kJiWRZEZqctK++cpTI/e5GSkMz0pleGYa+VmpDM9KIzNtf9C6SN9HanIShbl+3vPU5MEfk5Cc\nZBTmplOYmz7oryXBoXCXw1LX3EFlfStZacnkZqSQnZ6yL+icc7R1hqlv6aC+tZOG1g4a2zppauuk\nsS1EY2sHTe0hqhvaqKhrobKulYq6Vqob23rtOy7ISef0KSM4Y8pIzpg8gqmjcjAzQmHH9ppG1pbV\nsba0np17m7lg+ihmjB3GjLHDmFKYfVTCV2QoUrjLIe1tamdLVSNbqhrYsmv/fVVD20Fl01OSyEhN\nprm9k45Q/5dry81IYWxeBmPyMpk+Zhhj8vyBw4IcfyvMSacgN42stN7fpslJxtRRuUwdlctVpx7x\nrooEisL9GNPaEWJbdRPF1Y1srWqktSOEmZGcBMlmmBm7m9rYsquR4qpGdnc7CJiVlsy0UTmcPa2Q\naaNzOG54Jq0dIRpbO/e1zFs6QmSlpTAsM4XcjFSGZaSQm5FCTnoqOekp5KSnkJ2eTHZ6ykF9yyIS\nOwr3gGrrDLG1qonNuxrYWNnA5l0NbKlqoHRvy76uDzPf2g6HOWCM9LCMFKaNzmXhjNFMG53D8aNy\nmFqYw7jhmbrAg0iCULgHRHltC29s38Pr23ezcsdettU07Rs+l5psHF+Ywynj8/nQe4qYOiqHqaNy\nmDQy+6DWc9dBw2guCCwiQ5fCPQE1tHawsbKB9eX1vF1axxs7drNzTwvg+7HnTczn4pPGcOKYXKaP\nyWVSQfQHFhXqIsGgcB/iqhpaWVdWz9qyOtaV17GhooF39zTve35EdhqnTcrnhrMmM3/yCGaMHUay\nuk6kN+EQdLRAZ6u/dbRCZwu0NUJr3YG3UDskJfubRe6TUiApFZJTIDnNP05Jg9Ss/be0LP9cZxt0\nNEN7U+Q1W8CS9v9c1zaS0yE1A1Iit9TM/Y+TemmQhMPQ3ghtDeBCkDMaUg4xRLSjFZpr/Gt11bG3\n7QaQwn2ICIX9STobKurZWFHPunIf6N1HpUwuyObkcXl8+LTxzBiby4yxwxgzLEOtbelf8x746Vxo\n2RPvmkRvX+BnAc5/CLU3HFwuMx9yx0LuGMjIg6YaaNzlb611vW83NdN/2Liwv4Uj9ylpkJYL6TmQ\nlgNp2f6DrbMt8iHV5j+owqHIh1Oa/5nkNP/h1/N/0ZIgu9DXLXes/zDKHQuF0yF75KD82vbt5qBu\nXfrU1NbJa1t38/KWat4qrWNzZQMtHf5U9ySD4wtzeO/UAk4al8fJ4/KYMTaX3AydlSiHqeQVH+xn\n3gzDJ/rWbmqmv0/P9aGYMTxyGwbJqZHAC0G404dZuNPfQh2+ZR/u3B96HZEWenuTf64rlNOyIq+T\nGQnRjsjPR7YRaj8wNLu+TXS0+pZ/Z+QeID3P17XrBtBYBY2V0FAJDRVQ+y5kFfjwnHIu5Izyy+HO\nyDeJZn/ftU1LitySfTB3tu3/ZtD9G0JKpv+97PtWkXzgPoTa/XJP4U7Yu8PXr7N1//rLfgSnfXIQ\n/+AK96PGOcfmXY28sKmKlzZVs7JkDx0hR1ZaMnOKhnPd/AlMH5vLzLHDmDoqR8MEJbZKXvNdIBd8\n89DdGN0lJQFJPujlyDgHrbWRD6FKGDn4c+wr3AdZWW0LT6wp44k3y9m0y3+lnD4mlxvfO5lzphUy\nd1I+6SkKchlk774KRfOiD3aJLTPffZSZD6NmHJWXVLgPgvrWDp58q4LH15Txxnbfxzl3Yj7fvfIk\nLjppDKOHZcS5hnJMaWuEirfhvV+Md03kKFK4x4hzjuXb9vCnlTt5+p0KWjvCHF+YzZcvPIErTxnH\nhJFZ8a6iHKtK3/D9xhPPjHdN5ChSuB+hupYOHlhewsMrd1Kyu5nc9BQ+9J4irpk3ntlFeRrJIvFX\n8po/aFg0P941kaNI4X6YwmHHI6tK+eFfN7K7qZ3TJ4/glgumccmssQdMEysSd+++BmNO9qM95Jih\ncD8Ma3bW8q2l63hrZy1zJ+bzuxvnM2tcXryrJXKwznYoXQFzb4h3TeQoU7gPwJ6mdn7wlw08vLKU\nwtx0fnzNHK46dZy6XmToqljjx1erv/2Yo3CP0pqdtdz0wCqqGtpY/L4pfO78qTqpSIa+d1/z9xMU\n7scahXs/nHM8+Pq7fOf/1jF6WAZ/vmkBJxepC0YSRMlr/oSZnFHxrokcZQr3Q2hpD/H1x9fy2Ooy\nzjmhkLuvPYXhWWnxrpZIdMJh33Kf8f5410TiQOHeh5LdTXz6/lVs2tXALRdM45YLpulCFZJYqjf6\nU94nnhXvmkgcKNx78dbOWm5YsoJQ2PGbj5/GeSfqK60MUZuXwcYnYdEP/SRd3b37qr+fcMbRr5fE\nncK9hxc3VfGvD6xmZE4av79xPlMKc+JdJZGDhcPw8p3w4vf8cnI6XPZfB5YpeQ1yxkD+5KNfP4m7\nqGatN7NFZrbJzIrN7NZenp9gZi+Y2Ztm9raZXRr7qg6+x1aX8snfrWRyQTaP3XSWgn0wdV3INaj2\nbIO/3gY7V8R+26318MeP+mCffS2c9ilY8SvY8rf9ZZzz/e0Tzzx4jnE5JvTbcjezZOAe4EKgFFhh\nZkudc+u7FfsG8LBz7l4zmwk8DUwahPoOCuccv3x5Gz/4y0bOOn4kv/zYXA1zHKhQJ2z+C2x9Ac69\n9dCjM+rK4P6r4NSPwoLPH3q7L3wPqjfB1UsGJ6QaKuHtP8LaP/n5uAtPhMIZMGq6nxN8xPH+YgzR\ncg5WLYFlX/dznC//OUy/3E+1W3jikde3Zgs89BHYvdV3xZz+aT8H+Y5/wBOfhZteg6wRfl7z+jKY\noP72Y1U03TLzgWLn3DYAM3sIuBLoHu4O6Dq3OQ8oj2UlB9sdyzZx74tbuXz2WH50zZzBm4LXudgH\nVHuTv1pMvDTthtW/g5W/gbqdft3ON+DjT0Lm8IPLt9bBg1dDzSZ4/j9h5hWQP6n3bVeu9V0PLuzD\na/LZsalzRwtsfAre+gNsfd5vv+g0yBvvX3P9UvxbGn91nRFTfDAXnOgDf9QMf0vq8T5pqIQnbobi\nZ2HyOXDpnbD+CXjlv2HTGTDnI3Debf4iFqUr/ZmjpSug/E3/u5r4Xpi0ACYugPyJfpuhDthdDFXr\nYdd6eOM+P7/6Pz+x//eRmgEfvA9+dT783y1wze/3j2/XyUvHLHP9fD02s38CFjnnPhlZ/hhwunPu\n5m5lxgLPAPlANrDQObeql20tBhYDTJgwYW5JSUms9uOw/fWdCj7zwGqumz+B//eBWbEbEVNfDuVr\noGqd/6esWu9bU+/7Kpz9pSPffqjD/yO/8yjc8DSMm3vk2+xNOAx//y/Y8H/+aj1dc1Jn5vvLmL3z\nGITaYPL7YP6n/Xzhf7jOzx3+0ccOPMgX6vDBvuPvcMXP4Kkvw/HnwbUPHvy6zsFvL4GazX7SqzEn\nw8f+PLC6N1T638/eEmgoh/oKv66x0l8hZ1gRzLkW5lwHBd0untDR4l+3aqP/EKre5Eee7NnuZ1cE\nfwm2ce/xHwpF86GtHv7yNX8FoQtv91fZ6bpWZ1MN/P3HvuskHNq/DUuCUSfBuFP9h2TJK350C/gP\nmvRc31IPd+wvP+EsuOoXMHz8wfv7j7vgb9+GD/zCh/u6x+Hfth/8ISQJzcxWOefm9VsuinC/Gri4\nR7jPd859rluZL0W29SMzOxP4NTDLORfua7vz5s1zK1eujG5vBklZbQuX/ORlJhVk88hnziItJUYX\nzn13OSy5fP8/5fAJ/p+4vdEH20cehhMuPvzttzfBw//iW4hpuTDsOPj0y74FFw3nYPNf4bnbYfRJ\ncMkd/qt8Tx2t8MRNPiDHn+4vRdayB1r2+mtyJqf5cJy/2HdjdHnnMXjkRph2kQ/u5FT/mk/cDGse\ngCt/Dqde7wPvue/40D7+/ANf++2H4bFPwfv/G5p3+3KLX4TjTu1//8pWwfJfwLo/+79Bep6/huWw\nsZB7nL+ffA5MOntgF0vubPPdIZVvR1reb0DlO/vDetxcuOqXUDCt95+vfRdW/Np/MBbNg7Gn+Gt1\ndgmHfSOg5FUo+YefF2bUdBg1039TGDnt0H/jcMi/7yrX+knCRp8E1/8p+v2ThBDLcD8T+LZz7uLI\n8m0AzrnvdyuzDt+63xlZ3gac4Zyr6mu78Q73zlCY6361nPXl9Tz1+bOZVBCjro1wCO471wfS1Uv8\nP2XX9R47WuDXF0FtiQ+qEVMGvv2mGt/6rVgDl9/lW3gPfBDO+hxc9J/9/3z1Zlh2GxT/zV9Ls77M\nX8D3yp/B1IXdXme379vduRwWfhsWfOHALiXn/K2vcFz5W3jyC3Dy1XDVfftHdpxzq++aAB+W95zu\nw/9fX91/Obe2BvjpPP+h9cnn/EWR75rlW/nX/L731wuHYf3jsPxeH7ppuf4DZP5iGHl8/7+Xw9Xe\n7LtVmmvgxMsgOc4D0PaWwL0L/O/sgm/F5luiDCnRhns0zZYVwDQzm2xmacC1wNIeZd4FLoi88Awg\nA6geWJWPrp8+X8yKHXv5z6tmxS7YAdY86Ft2F94O4+fvD3bwFwr+8P2AwR8/5oNhIPbu8B8OVevh\nww/A3I/D1Atg3o3w6s/80Le+tNb5g3z3nun7xC/+PnxulQ/PjDx44EPw1Ff8t4LdW+HXC31oXb3E\nX8HnoKu626FbvfNu8OGy9k/wu/f7YJ/zEX+wtUtKOiz6ge8CeeO+/etfusN3nVx6p3+NjDzfzbF+\nqe+m6M3fvgmP3ABN1f5A45fWwyU/HNxgB9/tNGkBzLwy/sEOvq/+0jsBO/jbkBxbnHP93oBLgc3A\nVuDrkXW3A1dEHs8EXgHeAtYAF/W3zblz57p4Wb61xk2+9Un3xYfejO2GW+qcu+N45/7nIufC4b7L\nbX7WuW/lOffopw5drrvyNc7dOc25709wruS1A59rbXDurpOd+8ls/7indY87d8dU/5qPf9a5hqoD\nn29vdu4vtzn3rWHO3X2qcz+Y5NwPJztXsjy6uvUlHHZu2df9dpdc7lxHW+9l7v+Qc98rcq5hl3NV\nm5z7zgjnHr/pwHINVc59d9TB651zbvUD/jWe/JJzodCR1TkomnbHuwYySICVLorc7rdbZrDEq1um\ntrmdS+7+O+kpSTz5+bPJSY9ha+uZ/4BXfwqLX+i/b/ilO+CF/weX3AmnLz502TcfhKe+BFkF8NFH\nD+zf7rLjFVhyGZz2CbjsR35dYzU8/WU/YmPMbHj/3f4gYF+2vQSP3+S/YVz/8OF1G/XkHGx51p8C\nn97HeQM1xfDzM2D2h303Udlq/60ip/DAck9/1Xf33PIW5I3z67qOb0xaANc/OjRazyKDKNpumWPu\nP+HWR9dS09jGY/+64NDB3t7sR4Fk5ke34d1bfX/vKddHd9Dv7K/4A3/LbvMH+6ZffnA3R2ebH4Gx\naokfjfKh3xwceF0mLYAzPwuv/QymX+YPeD79VX8Q9/z/gAW37O/T7suUc+CWNf5xf2WjZQYnXHTo\nMgVT4cyb4JW7/fKiH/a+n2d9zg+5fO1nsOj7vn/5oev9AeurlyjYRbqJ0fCQxPBqcQ1/XVfJFxae\n0P+0vX+41h/UqyuLbuPPfMP3IV/wzejKJyX5kRX5k+Hhj8Hdc+CFSGAB1O6E3yzywf7eL8JH/9x3\nsHc5/xtQcAL877Xw6CdgxGQ/iuZ9X4k+rJNTYxfsA/G+r0LuWD+q6LRP9l5m+AR/gHbVEv97+sN1\nfnjldQ9F/yEscow4ZrplnHN88N5Xqaxr5YWvnEtG6iHG/m7/O/zucv+46DT4+NOHPktx6/P+jMuF\n3/ZBPBAdrbDpKXjzAX92J8630ivf8WOxP3AvzLg8+u2VrfbBPvcG35JPpDHOjVX+AzLjEB+8VRvh\n56dD5gg/Jvz6R/xBZZFjRCxHywTC8xurePPdWj53/rRDBzvASz+EnNG+ZV26Apb9e99lQ51+DpH8\nyXDGTQOvWGoGzPqQH+v9hbfh3H/346HzJ8KnXhhYsIPvU//8m/60/kQKdvBTFhwq2MEfb5h+uR9v\nf/H3FewifTgmOinDYcedyzYxcWQWV88rOnThHf/wJxpd/H1/gk7lWt/HO34+zL7mwLINlX4+j+qN\n8OEHfavzSAyfAOf+m79J3y7/iT/4qotQiPTpmGi5P7W2go2VDXxx4QmkJvezyy/+wLfa50WuFr/w\n2/6U76Wfh13r9pdb/4Qf4bHjFT86ZaAtbDl8OYV+ThrNdijSp8CHe2cozF3PbuaE0Tm8f85xhy68\n4xXfal/wBT8cEPzBxat/60/n/uPHoK4U/vwZePiffVfMZ/7e9wFAEZE4CXy3zGOry9hW08QvPzaX\n5P4mBXupR6u9S+4YP9RuyeV+VItzcM6/+REe8RhZIiLSj0C33Ns6Q9z93BbmFOVx0czRhy5c8ips\nf9mPB+9qtXc38Sx/pZsxs+HGZXDevyvYRWTICnTL/aE3dlJW28L3P3gy1tU/W/w3SM32Qxy7n/Ty\n4g8ge5QfQtiXeTf6m4jIEBfYcG/tCPHT54s5ffIIzp5W4Fe+86ifihb8SS9TF8IJi/zc3Ntfgou/\nd/BFhkVEElBgw/2lzdXUNLbxo2vm+FZ7xdvw+Gdh/Bn+0mRbnvG3tZH5rvtrtYuIJJDAhvuz63cx\nLCOFs44f6edAf+gj/oIUH77fnywz64N+7vWy1b6rZvx8tdpFJDACGe6hsOP5jVWcN30UqYT8VYua\nquGGvxx44eakZBh/mr+JiARIIMN9Vcle9jS1c+HM0X5qgJJ/+KsBHWq6WxGRAAnkUMi/bdhFarKx\nsGWZvyjxmTfDnA/Hu1oiIkdN4MLdOcez63dx8UTIeOZr/lJjC78T72qJiBxVgQv3rdWNbK9p4urC\nMgi1w3nf0EUcROSYE7hwf2b9LgDmpm6HpFQYMyvONRIROfoCF+7Prt/FyePyyKl52wf7kU7DKyKS\ngAIV7lUNrazZWcuFMwqhfA0cp9ExInJsClS4P7ehCufgsuMaob0Bxs2Nd5VEROIiUOH+7PpdFOVn\nMqV9k1+hce0icowKTLg3tXXyj+IaLpw5Gitb7Wd+LDgh3tUSEYmLwIT737dU094Z9mellq+G405J\nvAtEi4jESGDC/Zn1u8jLTGX++Bx/UWt1yYjIMSwQ4d4ZCvP8xirOnz6KlJoN/uQljZQRkWNYIMJ9\nzc5aaps7fJdM2Sq/Ui13ETmGBSLct1Y3AjC7KA/K3oSskTB8YpxrJSISP4EI97K9LSQZjB6WETmY\n+h7oumaqiMgxKBjhXtvK6GEZpHY2Q/VGdcmIyDEvEOFeXtvCuOGZUPEWuLDOTBWRY14wwr2uheOG\nZ/ouGdBIGRE55iV8uIfDjoraVh/uZashbzzkFMa7WiIicRVVuJvZIjPbZGbFZnZrH2WuMbP1ZrbO\nzP43ttXsW01jG+2hMOOGZ/hhkOpvFxHp/wLZZpYM3ANcCJQCK8xsqXNufbcy04DbgAXOub1mNmqw\nKtxTWW0LAJOyWqC2BObdeLReWkRkyIqm5T4fKHbObXPOtQMPAVf2KPMp4B7n3F4A51xVbKvZt/La\nVgAmtW32K9RyFxGJKtzHATu7LZdG1nV3AnCCmb1iZsvNbFFvGzKzxWa20sxWVldXH16NeyirbQag\nsH4dYDD2lJhsV0QkkUUT7r2dDeR6LKcA04BzgeuA/zGz4Qf9kHP3OefmOefmFRbG5qBneW0ruekp\nZFS95af4zRgWk+2KiCSyaMK9FBjfbbkIKO+lzBPOuQ7n3HZgEz7sB11ZbUvkYOpqdcmIiEREE+4r\ngGlmNtnM0oBrgaU9yjwOnAdgZgX4bpptsaxoX8r2tnBSbiM0VWl8u4hIRL/h7pzrBG4GlgEbgIed\nc+vM7HYzuyJSbBmw28zWAy8AX3XO7R6sSndXXtfC3JTtfkEtdxERIIqhkADOuaeBp3us+2a3xw74\nUuR21DS1dVLb3MGJ4WJISoHRs47my4uIDFkJfYZqRZ0f4z6+ZROMmgmpGXGukYjI0JDQ4V66twVw\njKhbB8edGu/qiIgMGQkd7uW1rYy3KlLa69TfLiLSTYKHewunJkcOpqrlLiKyT1QHVIeq8toWzkgv\nAdKhcEa8qyMiMmQkdMu9tLaF2cnbYcwsSEmLd3VERIaMhA73ir1NTO0sVpeMiEgPCRvuobAjo34H\nGeFmhbuISA8JG+7VDW3MZKtfULiLiBwgYcO9rLaZ2UnbCSVnQMGJ8a6OiMiQksDh3srJSdtoL5wF\nyQk96EdEJOYSNtwr9jYyy3aQXKSTl0REekrYJm975UayrA3Gz4t3VUREhpyEbbln1az1D3QwVUTk\nIAkb7oX162i1TBg5Nd5VEREZchI23Ce1b6YiezokJewuiIgMmoRMxoamZk50O6jP18U5RER6k5Dh\nvnv726RbB51j5sS7KiIiQ1JChntryUoA0idopIyISG8SMtyTKtdQ77IonDA93lURERmSEjLcc/es\nZa2bQmGurpkqItKbxAv3zjYKm4rZnjaNpCSLd21ERIakxAv3XetIoZOqnJnxromIyJCVeOFe/iYA\nTQWz41wREZGhK+HmlgnlH89DoYVkFU6Kd1VERIashGu5Vxacztc7buS4/Kx4V0VEZMhKuHAvr20B\n4LjhmXGuiYjI0JWw4T5O4S4i0qeEC/fSvV0td41xFxHpS8IdUL3+9AksmFpAVlrCVV1E5KhJuIQc\nnpXGKVlp8a6GiMiQlnDdMiIi0j+Fu4hIACncRUQCSOEuIhJAUYW7mS0ys01mVmxmtx6i3D+ZmTMz\nXUVDRCSO+g13M0sG7gEuAWYC15nZQVMymlku8Hng9VhXUkREBiaalvt8oNg5t8051w48BFzZS7nv\nAncArTGsn4iIHIZown0csLPbcmlk3T5mdiow3jn35KE2ZGaLzWylma2srq4ecGVFRCQ60YR7b5c7\ncvueNEsC7gK+3N+GnHP3OefmOefmFRYWRl9LEREZkGjCvRQY3225CCjvtpwLzAJeNLMdwBnAUh1U\nFRGJn2jCfQUwzcwmm1kacC2wtOtJ51ydc67AOTfJOTcJWA5c4ZxbOSg1FhGRfvUb7s65TuBmYBmw\nAXjYObfOzG43sysGu4IiIjJwUU0c5px7Gni6x7pv9lH23COvloiIHAmdoSoiEkAKdxGRAFK4i4gE\nkMJdRCSAFO4iIgGkcBcRCSCFu4hIACncRUQCSOEuIhJACncRkQBSuIuIBJDCXUQkgBTuIiIBpHAX\nEQkghbuISAAp3EVEAkjhLiKcBTfmAAAHBklEQVQSQAp3EZEAUriLiASQwl1EJIAU7iIiAaRwFxEJ\nIIW7iEgAKdxFRAJI4S4iEkAKdxGRAFK4i4gEkMJdRCSAFO4iIgGkcBcRCSCFu4hIACncRUQCSOEu\nIhJACncRkQCKKtzNbJGZbTKzYjO7tZfnv2Rm683sbTN7zswmxr6qIiISrX7D3cySgXuAS4CZwHVm\nNrNHsTeBec652cAjwB2xrqiIiEQvmpb7fKDYObfNOdcOPARc2b2Ac+4F51xzZHE5UBTbaoqIyEBE\nE+7jgJ3dlksj6/ryCeAvvT1hZovNbKWZrayuro6+liIiMiDRhLv1ss71WtDso8A84M7ennfO3eec\nm+ecm1dYWBh9LUVEZEBSoihTCozvtlwElPcsZGYLga8D5zjn2mJTPRERORzRtNxXANPMbLKZpQHX\nAku7FzCzU4FfAlc456piX00RERmIfsPdOdcJ3AwsAzYADzvn1pnZ7WZ2RaTYnUAO8CczW2NmS/vY\nnIiIHAXRdMvgnHsaeLrHum92e7wwxvUSEZEjoDNURUQCSOEuIhJACncRkQBSuIuIBJDCXUQkgBTu\nIiIBpHAXEQkghbuISAAp3EVEAkjhLiISQAp3EZEAUriLiASQwl1EJIAU7iIiAaRwFxEJIIW7iEgA\nKdxFRAJI4S4iEkAKdxGRAFK4i4gEkMJdRCSAFO4iIgGkcBcRCSCFu4hIACncRUQCSOEuIhJACncR\nkQBSuIuIBJDCXUQkgBTuIiIBpHAXEQkghbuISAAp3EVEAkjhLiISQAp3EZEAiirczWyRmW0ys2Iz\nu7WX59PN7I+R5183s0mxrqiIiESv33A3s2TgHuASYCZwnZnN7FHsE8Be59xU4C7gh7GuqIiIRC+a\nlvt8oNg5t8051w48BFzZo8yVwO8ijx8BLjAzi101RURkIFKiKDMO2NltuRQ4va8yzrlOM6sDRgI1\n3QuZ2WJgcWSx0cw2HU6lgYKe205wQdqfIO0LaH+GsiDtC0S/PxOj2Vg04d5bC9wdRhmcc/cB90Xx\nmoeukNlK59y8I93OUBGk/QnSvoD2ZygL0r5A7Pcnmm6ZUmB8t+UioLyvMmaWAuQBe2JRQRERGbho\nwn0FMM3MJptZGnAtsLRHmaXAv0Qe/xPwvHPuoJa7iIgcHf12y0T60G8GlgHJwG+cc+vM7HZgpXNu\nKfBr4H4zK8a32K8dzEoTg66dISZI+xOkfQHtz1AWpH2BGO+PqYEtIhI8OkNVRCSAFO4iIgGUcOHe\n31QIQ52Z/cbMqszsnW7rRpjZs2a2JXKfH886RsvMxpvZC2a2wczWmdktkfWJuj8ZZvaGmb0V2Z/v\nRNZPjkyrsSUyzUZavOsaLTNLNrM3zezJyHIi78sOM1trZmvMbGVkXaK+14ab2SNmtjHy/3NmrPcl\nocI9yqkQhrolwKIe624FnnPOTQOeiywngk7gy865GcAZwGcjf49E3Z824Hzn3BzgFGCRmZ2Bn07j\nrsj+7MVPt5EobgE2dFtO5H0BOM85d0q38eCJ+l67G/irc246MAf/N4rtvjjnEuYGnAks67Z8G3Bb\nvOt1GPsxCXin2/ImYGzk8VhgU7zreJj79QRwYRD2B8gCVuPPxq4BUiLrD3gPDuUb/pyU54DzgSfx\nJxsm5L5E6rsDKOixLuHea8AwYDuRAS2DtS8J1XKn96kQxsWpLrE02jlXARC5HxXn+gxYZCbQU4HX\nSeD9iXRjrAGqgGeBrUCtc64zUiSR3nM/Ab4GhCPLI0ncfQF/1vszZrYqMpUJJOZ7bQpQDfw20mX2\nP2aWTYz3JdHCPappDuToMrMc4FHgC865+njX50g450LOuVPwrd75wIzeih3dWg2cmV0OVDnnVnVf\n3UvRIb8v3Sxwzr0H3y37WTN7X7wrdJhSgPcA9zrnTgWaGITupEQL92imQkhEu8xsLEDkvirO9Yma\nmaXig/1B59xjkdUJuz9dnHO1wIv4YwnDI9NqQOK85xYAV5jZDvxMrufjW/KJuC8AOOfKI/dVwJ/x\nH76J+F4rBUqdc69Hlh/Bh31M9yXRwj2aqRASUffpG/4F33c95EWmdf41sME59+NuTyXq/hSa2fDI\n40xgIf5A1wv4aTUgQfbHOXebc67IOTcJ/3/yvHPuehJwXwDMLNvMcrseAxcB75CA7zXnXCWw08xO\njKy6AFhPrPcl3gcXDuNgxKXAZnxf6NfjXZ/DqP8fgAqgA/8J/gl8X+hzwJbI/Yh41zPKfXkv/mv9\n28CayO3SBN6f2cCbkf15B/hmZP0U4A2gGPgTkB7vug5wv84FnkzkfYnU+63IbV3X/34Cv9dOAVZG\n3muPA/mx3hdNPyAiEkCJ1i0jIiJRULiLiASQwl1EJIAU7iIiAaRwFxEJIIW7iEgAKdxFRALo/wPa\npGOmJyB8KgAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<matplotlib.figure.Figure at 0x7f56e46957f0>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "epochs = 60\n",
    "\n",
    "a, b = [], []\n",
    "for epoch in range(epochs):\n",
    "    if epoch  == 40:\n",
    "        trainer.set_learning_rate(0.01)\n",
    "    train_loss = 0.\n",
    "    train_acc = 0.\n",
    "    start = time()\n",
    "    for data, label in train_data:\n",
    "        data = data.as_in_context(ctx)\n",
    "        label = label.as_in_context(ctx)\n",
    "        with autograd.record():\n",
    "            output = net(data)\n",
    "            l = loss(output, label)\n",
    "        l.backward()\n",
    "        trainer.step(batch_size)\n",
    "        train_loss = train_loss + nd.mean(l).asscalar()\n",
    "        train_acc = train_acc + accuracy(output, label)\n",
    "    test_acc = evaluate_accuracy(test_data, net, ctx)\n",
    "    \n",
    "    if epoch%10 == 0:\n",
    "        print(epoch, 'loss:%.4f tracc:%.4f teacc:%.4f time:%.3f'%(\n",
    "            train_loss/len(train_data), train_acc/len(train_data), test_acc, time()-start)) \n",
    "    a.append(train_acc/len(train_data))\n",
    "    b.append(test_acc)\n",
    "\n",
    "print('tracc:%f teacc:%f'%(train_acc/len(train_data), test_acc))\n",
    "plt.plot(np.arange(epochs), a, np.arange(epochs), b)\n",
    "plt.ylim(0,1)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
